{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Extending PyTorch differentiable functions\n",
    "\n",
    "In this notebook you'll see how to add your own custom differentiable function, for which you need to specify both the `forward` and the `backward` pass."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import some libraries\n",
    "import torch\n",
    "import numpy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For a gentle introduction, see the [PyTorch extension](https://pytorch.org/docs/stable/notes/extending.html) tutorial.\n",
    "\n",
    "The source for `torch.autograd.Function` is available [here](https://github.com/pytorch/pytorch/blob/master/torch/autograd/function.py).\n",
    "These are the two methods that we have to override:\n",
    "\n",
    "```python\n",
    "@staticmethod\n",
    "def forward(ctx, *args, **kwargs):\n",
    "    \"\"\"Performs the operation.\n",
    "    This function is to be overridden by all subclasses.\n",
    "    It must accept a context ctx as the first argument, followed by any\n",
    "    number of arguments (tensors or other types).\n",
    "    The context can be used to store tensors that can be then retrieved\n",
    "    during the backward pass.\n",
    "    \"\"\"\n",
    "    raise NotImplementedError\n",
    "\n",
    "@staticmethod\n",
    "def backward(ctx, *grad_outputs):\n",
    "    \"\"\"Defines a formula for differentiating the operation.\n",
    "    This function is to be overridden by all subclasses.\n",
    "    It must accept a context :attr:`ctx` as the first argument, followed by\n",
    "    as many outputs as :func:`forward` returned, and it should return as many\n",
    "    tensors as there were inputs to :func:`forward`. Each argument is the\n",
    "    gradient w.r.t. the given output, and each returned value should be the\n",
    "    gradient w.r.t. the corresponding input.\n",
    "    The context can be used to retrieve tensors saved during the forward\n",
    "    pass. It also has an attribute :attr:`ctx.needs_input_grad` as a tuple\n",
    "    of booleans representing whether each input needs gradient. E.g.,\n",
    "    :func:`backward` will have ``ctx.needs_input_grad[0] = True`` if the\n",
    "    first input to :func:`forward` needs gradient computed w.r.t. the\n",
    "    output.\n",
    "    \"\"\"\n",
    "    raise NotImplementedError\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Custom addition module\n",
    "class MyAdd(torch.autograd.Function):\n",
    "    \"\"\"Element-wise addition `x1 + x2` with a hand-written backward pass.\n",
    "\n",
    "    The inputs may have different but broadcastable shapes; each returned\n",
    "    gradient has the shape of its corresponding input.\n",
    "    \"\"\"\n",
    "\n",
    "    @staticmethod\n",
    "    def forward(ctx, x1, x2):\n",
    "        \"\"\"Return `x1 + x2` and keep both inputs on `ctx` for backward.\"\"\"\n",
    "        # ctx is a context where we can save\n",
    "        # computations for backward.\n",
    "        ctx.save_for_backward(x1, x2)\n",
    "        return x1 + x2\n",
    "\n",
    "    @staticmethod\n",
    "    def backward(ctx, grad_output):\n",
    "        \"\"\"Return the gradients w.r.t. `x1` and `x2`, given the gradient w.r.t. the sum.\"\"\"\n",
    "        x1, x2 = ctx.saved_tensors\n",
    "        # d(x1 + x2)/dx1 = d(x1 + x2)/dx2 = 1, so the incoming gradient\n",
    "        # passes straight through. If an input was broadcast in forward,\n",
    "        # sum_to_size adds up the broadcast dimensions so the gradient\n",
    "        # gets that input's shape back.\n",
    "        # Inputs that do not need a gradient get None.\n",
    "        grad_x1 = grad_x2 = None\n",
    "        if ctx.needs_input_grad[0]:\n",
    "            grad_x1 = grad_output.sum_to_size(x1.shape)\n",
    "        if ctx.needs_input_grad[1]:\n",
    "            grad_x2 = grad_output.sum_to_size(x2.shape)\n",
    "        # need to return grads in order\n",
    "        # of inputs to forward (excluding ctx)\n",
    "        return grad_x1, grad_x2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x1: tensor([-0.0141, -2.2316,  0.0897], requires_grad=True)\n",
      "x2: tensor([ 0.1969, -0.2429, -0.4893], requires_grad=True)\n",
      " y: tensor([ 0.1828, -2.4745, -0.3997], grad_fn=<MyAddBackward>)\n",
      " z: -0.8971386551856995, z.grad_fn: <MeanBackward0 object at 0x103338ee0>\n",
      "x1.grad: tensor([0.3333, 0.3333, 0.3333])\n",
      "x2.grad: tensor([0.3333, 0.3333, 0.3333])\n"
     ]
    }
   ],
   "source": [
    "# Exercise MyAdd on two random vectors and inspect the gradients\n",
    "myadd = MyAdd.apply  # alias the apply method, which is how the Function is called\n",
    "x1 = torch.randn(3, requires_grad=True)\n",
    "x2 = torch.randn(3, requires_grad=True)\n",
    "print(f'x1: {x1}')\n",
    "print(f'x2: {x2}')\n",
    "y = myadd(x1, x2)\n",
    "print(f' y: {y}')\n",
    "# reduce to a scalar so that backward() needs no explicit gradient argument\n",
    "z = y.mean()\n",
    "print(f' z: {z}, z.grad_fn: {z.grad_fn}')\n",
    "z.backward()\n",
    "# the mean over 3 elements gives dz/dy = 1/3, which MyAdd hands to both inputs\n",
    "print(f'x1.grad: {x1.grad}')\n",
    "print(f'x2.grad: {x2.grad}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Custom split module\n",
    "class MySplit(torch.autograd.Function):\n",
    "    \"\"\"Duplicate a tensor into two independent copies.\n",
    "\n",
    "    In the backward pass the gradients arriving at the two copies are\n",
    "    added together to give the gradient of the single input.\n",
    "    \"\"\"\n",
    "\n",
    "    @staticmethod\n",
    "    def forward(ctx, x):\n",
    "        \"\"\"Return two clones of `x`.\"\"\"\n",
    "        # Nothing is saved on ctx: backward only needs the two incoming\n",
    "        # gradients, not the value of x.\n",
    "        x1 = x.clone()\n",
    "        x2 = x.clone()\n",
    "        return x1, x2\n",
    "\n",
    "    @staticmethod\n",
    "    def backward(ctx, grad_x1, grad_x2):\n",
    "        \"\"\"Return the gradient w.r.t. `x`: the sum of both branch gradients.\"\"\"\n",
    "        # backward receives one gradient per output of forward\n",
    "        print(f'grad_x1: {grad_x1}')\n",
    "        print(f'grad_x2: {grad_x2}')\n",
    "        return grad_x1 + grad_x2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " x: tensor([-1.7785,  0.6111,  0.1653, -0.8121], requires_grad=True)\n",
      "x1: tensor([-1.7785,  0.6111,  0.1653, -0.8121], grad_fn=<MySplitBackward>)\n",
      "x2: tensor([-1.7785,  0.6111,  0.1653, -0.8121], grad_fn=<MySplitBackward>)\n",
      " y: tensor([-3.5571,  1.2221,  0.3305, -1.6241], grad_fn=<AddBackward0>)\n",
      " z: -0.9071540236473083, z.grad_fn: <MeanBackward0 object at 0x1215816d0>\n",
      "grad_x1: tensor([0.2500, 0.2500, 0.2500, 0.2500])\n",
      "grad_x2: tensor([0.2500, 0.2500, 0.2500, 0.2500])\n",
      " x.grad: tensor([0.5000, 0.5000, 0.5000, 0.5000])\n"
     ]
    }
   ],
   "source": [
    "# Exercise MySplit: both copies feed the same sum, so x collects two gradients\n",
    "split = MySplit.apply\n",
    "x = torch.randn(4, requires_grad=True)\n",
    "print(f' x: {x}')\n",
    "x1, x2 = split(x)\n",
    "print(f'x1: {x1}')\n",
    "print(f'x2: {x2}')\n",
    "y = x1 + x2\n",
    "print(f' y: {y}')\n",
    "z = y.mean()\n",
    "print(f' z: {z}, z.grad_fn: {z.grad_fn}')\n",
    "# MySplit.backward prints grad_x1 and grad_x2 while this call runs\n",
    "z.backward()\n",
    "# each branch contributes 1/4 per element, hence 1/2 in total\n",
    "print(f' x.grad: {x.grad}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Custom max module\n",
    "class MyMax(torch.autograd.Function):\n",
    "    \"\"\"Maximum over all elements of a tensor, computed with NumPy.\n",
    "\n",
    "    The gradient flows only to the position(s) holding the maximum. If the\n",
    "    maximum occurs more than once, the gradient is shared equally among\n",
    "    the tied positions.\n",
    "    \"\"\"\n",
    "\n",
    "    @staticmethod\n",
    "    def forward(ctx, x):\n",
    "        \"\"\"Return the largest element of `x` as a 0-dim tensor.\"\"\"\n",
    "        # example where we explicitly use non-torch code\n",
    "        # (.cpu() first, because .numpy() only works on CPU tensors)\n",
    "        maximum = x.detach().cpu().numpy().max()\n",
    "        # 1 where x attains the maximum, 0 elsewhere, in the dtype of x\n",
    "        argmax = x.detach().eq(maximum).to(x.dtype)\n",
    "        ctx.save_for_backward(argmax)\n",
    "        # build the result with the dtype and on the device of the input\n",
    "        return torch.tensor(maximum, dtype=x.dtype, device=x.device)\n",
    "\n",
    "    @staticmethod\n",
    "    def backward(ctx, grad_output):\n",
    "        \"\"\"Route `grad_output` to the maximal element(s) of the input.\"\"\"\n",
    "        argmax = ctx.saved_tensors[0]\n",
    "        # Split the gradient between ties. The clamp avoids 0/0 when no\n",
    "        # element compares equal to the maximum (e.g. the maximum is NaN).\n",
    "        num_ties = argmax.sum().clamp(min=1)\n",
    "        return grad_output * argmax / num_ties"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x: tensor([ 0.4467, -1.4328, -1.7784, -0.1107, -0.8490], requires_grad=True)\n",
      "y: 0.4466537535190582, y.grad_fn: <torch.autograd.function.MyMaxBackward object at 0x1032c0ad0>\n",
      "x.grad: tensor([1., 0., 0., 0., 0.])\n"
     ]
    }
   ],
   "source": [
    "# Exercise MyMax on a random vector\n",
    "mymax = MyMax.apply\n",
    "x = torch.randn(5, requires_grad=True)\n",
    "print(f'x: {x}')\n",
    "y = mymax(x)\n",
    "print(f'y: {y}, y.grad_fn: {y.grad_fn}')\n",
    "# y is already a scalar, so backward() needs no argument\n",
    "y.backward()\n",
    "# the gradient is non-zero only where x holds its maximum\n",
    "print(f'x.grad: {x.grad}')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:pDL] *",
   "language": "python",
   "name": "conda-env-pDL-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
